# FR: KADID-7 NR:KADID-7
import pyiqa
import torch
from langchain.tools import tool
from PIL import Image
import torchvision.transforms as transforms
import os
from io import BytesIO
from typing import Union, Tuple
import cv2
import numpy as np
import json
# from segment_anything import sam_model_registry, SamPredictor
# from groundingdino.util.inference import load_model, load_image, predict, annotate, Model
import base64

# Expose only GPU index 1 to this process.
# NOTE(review): torch is already imported above this line, so the setting only takes
# effect if CUDA has not been initialised yet; consider moving it before the imports.
os.environ.update(CUDA_VISIBLE_DEVICES="1")

# Default compute device: the (single visible) GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


import functools
import inspect

def retry_tool_call(tool_func, max_retries=2):
    """Wrap ``tool_func`` so that each call is attempted up to ``max_retries`` times.

    Works for both plain functions and coroutine functions. Every failed attempt is
    printed. If all attempts fail, a ``RuntimeError`` is raised with the last
    underlying exception attached as ``__cause__``, so its traceback is preserved.

    Args:
        tool_func: Sync or ``async`` callable to wrap.
        max_retries: Total number of attempts. With a value below 1 the wrapped
            function is never called and ``RuntimeError`` is raised immediately.

    Returns:
        A wrapper carrying ``tool_func``'s metadata (via ``functools.wraps``).
    """
    if inspect.iscoroutinefunction(tool_func):
        @functools.wraps(tool_func)
        async def async_wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(max_retries):
                try:
                    return await tool_func(*args, **kwargs)
                except Exception as e:
                    last_error = e
                    print(f"[{tool_func.__name__}] failed (attempt {attempt+1}): {e}")
            # Chain the final failure so callers can see the real root cause.
            raise RuntimeError(f"[{tool_func.__name__}] failed after {max_retries} attempts.") from last_error
        return async_wrapper

    @functools.wraps(tool_func)
    def sync_wrapper(*args, **kwargs):
        last_error = None
        for attempt in range(max_retries):
            try:
                return tool_func(*args, **kwargs)
            except Exception as e:
                last_error = e
                print(f"[{tool_func.__name__}] failed (attempt {attempt+1}): {e}")
        # Chain the final failure so callers can see the real root cause.
        raise RuntimeError(f"[{tool_func.__name__}] failed after {max_retries} attempts.") from last_error
    return sync_wrapper
    
# Per-model fitting parameters, read once at import time. ``logistic`` reads
# ``model_data_params[model_name]['beta']`` from this mapping.
model_data_path = '/root/IQA/IQA-Agent/iqa_models_results/model_fitting_result.json'
# Explicit encoding so decoding does not depend on the platform's locale default.
with open(model_data_path, 'r', encoding='utf-8') as f:
    model_data_params = json.load(f)
def logistic(model_name, X, params=None):
    """Map raw metric scores through the 5-parameter logistic fitted for ``model_name``.

        yhat = beta1 * (0.5 - 1 / (1 + exp(beta2 * (X - beta3)))) + beta4 * X + beta5

    Args:
        model_name: Key identifying the model's fitted parameters.
        X: Tensor of raw scores. The betas are created with its dtype and device.
        params: Optional mapping ``{model_name: {'beta': [b1, b2, b3, b4, b5]}}``.
            Defaults to the module-level ``model_data_params``.

    Returns:
        Tensor of mapped scores, broadcast to the shape of ``X``.

    Raises:
        KeyError: if no fitted parameters exist for ``model_name``.
    """
    if params is None:
        params = model_data_params
    try:
        beta_values = params[model_name]['beta']
    except KeyError:
        raise KeyError(f"No logistic fitting parameters found for model '{model_name}'") from None
    beta = torch.tensor(beta_values, dtype=X.dtype, device=X.device)
    beta1, beta2, beta3, beta4, beta5 = beta
    logistic_part = 0.5 - 1.0 / (1 + torch.exp(beta2 * (X - beta3)))
    return beta1 * logistic_part + beta4 * X + beta5
class ImageIQAWrapper:
    """Wrap a ``pyiqa`` metric: load images, run the metric, rescale via ``logistic``."""

    _FILE_SCHEME = "file://"

    def __init__(self, model_name: str, device: str = "cuda", description: str = None):
        """Create the underlying ``pyiqa`` metric.

        Args:
            model_name: pyiqa metric name. It is also the key used by ``logistic``
                to look up the fitted rescaling parameters.
            device: Requested torch device. A CUDA request falls back to CPU when
                CUDA is unavailable, mirroring the module-level ``device`` choice.
            description: Free-form text stored on the instance. It is not used here.
        """
        self.model_name = model_name
        self.device = self._resolve_device(device)
        self.model = pyiqa.create_metric(model_name, device=self.device)
        try:
            self.model_dtype = next(self.model.parameters()).dtype
        except StopIteration:
            # Metrics without any parameters: default to float32.
            self.model_dtype = torch.float32
        self.model = self.model.to(self.device).to(self.model_dtype)
        self.description = description

    @staticmethod
    def _resolve_device(device) -> torch.device:
        """Return ``torch.device(device)``, using CPU if CUDA is requested but unavailable."""
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return resolved

    @classmethod
    def _load_image(cls, image_input):
        """Decode raw bytes, a base64 string, or a file path into an RGB PIL image."""
        if isinstance(image_input, str):
            if len(image_input) > 300:  # heuristic: long strings are treated as base64
                payload = image_input
                if payload.startswith("data:") and "," in payload:
                    # Strip a data-URL header such as "data:image/png;base64,".
                    payload = payload.split(",", 1)[1]
                source = BytesIO(base64.b64decode(payload))
            elif image_input.startswith(cls._FILE_SCHEME):
                # Remove only the leading scheme, not later occurrences in the path.
                source = image_input[len(cls._FILE_SCHEME):]
            else:
                source = image_input
        else:
            source = BytesIO(image_input)
        with Image.open(source) as img:
            return img.convert("RGB")

    def preprocess(self, image_input: Union[str, bytes]) -> torch.Tensor:
        """Load an image as a ``(1, 3, H, W)`` tensor on ``self.device`` in the model dtype.

        Accepted inputs:
            * ``bytes``: encoded image data.
            * ``str`` longer than 300 characters: base64 data, optionally wrapped in a
              ``data:<mime>;base64,`` URL.
            * any other ``str``: a filesystem path, optionally prefixed by ``file://``.

        Raises:
            ValueError: if the image cannot be decoded or opened.
        """
        try:
            img = self._load_image(image_input)
        except Exception as e:
            raise ValueError(f"[Preprocess Error] Failed to load image: {e}") from e

        img_tensor = transforms.ToTensor()(img).unsqueeze(0).to(self.device)
        return img_tensor.to(dtype=self.model_dtype)

    def predict(self, image_input: Union[str, bytes, Tuple[Union[str, bytes], Union[str, bytes]]]) -> float:
        """Score a single image (NR) or a ``(reference, distorted)`` pair (FR).

        For tuple input, the first element is the reference image and the second is
        the distorted image, which is the order every ``*_tool`` in this module passes.
        The metric is invoked as ``model(distorted, reference)``, following pyiqa's
        convention of target/distorted first and reference second.

        Returns:
            The logistic-rescaled score as a Python float.
        """
        if isinstance(image_input, tuple):
            ref_input, dist_input = image_input
            ref_tensor = self.preprocess(ref_input)
            dist_tensor = self.preprocess(dist_input)
            with torch.no_grad():
                raw_score = self.model(dist_tensor, ref_tensor)
        else:
            img_tensor = self.preprocess(image_input)
            with torch.no_grad():
                raw_score = self.model(img_tensor)

        value = logistic(self.model_name, raw_score).item()
        print(f"[{self.model_name}] score: {value}")
        return value

    def run(self, image_input: Union[str, bytes, Tuple[Union[str, bytes], Union[str, bytes]]]):
        """Alias for :meth:`predict`."""
        return self.predict(image_input)
@tool
def TopIQ_FR_tool(reference_image: str, distorted_image: str):
    """
    This is a Full-reference IQA model. 
    Best at evaluating: 
    - Blurs (lens blur, motion blur)
    - Color distortions (color diffusion, color shift, color quantization, color saturation)
    - Compression (JPEG2000 and JPEG)
    - Noise (white noise, color component noise, impulse noise, multiplicative noise, denoise artifact)
    - Brightness change (brighten, darken, mean shift)
    - Spatial distortions (jitter, non-eccentricity patch, pixelate, otsu quantization, color block)
    - Sharpness and contrast
    """
    # Docstring above is the runtime tool description; device follows the module-level CPU/GPU choice.
    wrapper = ImageIQAWrapper(model_name='topiq_fr', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def AHIQ_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='ahiq', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def FSIM_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model.This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='fsim', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def LPIPS_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='lpips', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def DISTS_tool(reference_image: str, distorted_image: str):
    """
    This is a Full-reference IQA model. 
    Best at evaluating: 
    - Blurs (gaussian blur)
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='dists', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def WaDIQaM_FR_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='wadiqam_fr', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def PieAPP_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='pieapp', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def MS_SSIM_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='ms_ssim', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def GMSD_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='gmsd', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def SSIM_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='ssim', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def CKDN_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='ckdn', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def VIF_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='vif', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def PSNR_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='psnr', device=str(device))
    return wrapper.run((reference_image, distorted_image))

@tool
def VSI_tool(reference_image: str, distorted_image: str):
    """
    This is Full-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='vsi', device=str(device))
    return wrapper.run((reference_image, distorted_image))


@tool
def QAlign_tool(image_url: str):
    """
    This is No-reference IQA model. 
    Best at evaluating: 
    - Blurs (gaussian blur, motion blur)
    - Color distortions (color shift, color quantization, color saturation)
    - Noise (white noise, color component noise, impulse noise, multiplicative noise)
    - Brightness change (brighten, darken, mean shift)
    - Spatial distortions (jitter, otsu quantization)
    - Sharpness
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='qalign', device=str(device))
    return wrapper.run(image_url)

@tool
def CLIPIQA_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='clipiqa+_rn50_512', device=str(device))
    return wrapper.run(image_url)

@tool
def UNIQUE_tool(image_url: str):
    """
    This is No-reference IQA model. 
    Best at evaluating: 
    - Blurs (lens blur)
    - Compression (JPEG, JPEG2000)
    - Noise (denoise artifact)
    - Spatial distortions (non-eccentricity patch, pixelate, color block)
    - Contrast
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='unique', device=str(device))
    return wrapper.run(image_url)

@tool
def HyperIQA_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='hyperiqa', device=str(device))
    return wrapper.run(image_url)

@tool
def TReS_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='tres', device=str(device))
    return wrapper.run(image_url)

# @tool
# def MUSIQ_tool(image_url: str):
#     """
#     This is No-reference IQA model. The tool **did not** achieve a top-3 ranking for any distortion type.
#     """
#     description = "musiq has an overall confidence of 0.56. It **did not** achieve a top-3 ranking for any distortion type."
#     wrapper = ImageIQAWrapper(model_name='musiq', device='cuda')
#     return wrapper.run(image_url)

@tool
def WaDIQaM_NR_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='wadiqam_nr', device=str(device))
    return wrapper.run(image_url)

@tool
def DBCNN_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='dbcnn', device=str(device))
    return wrapper.run(image_url)

@tool
def ARNIQA_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='arniqa', device=str(device))
    return wrapper.run(image_url)

@tool
def NIMA_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='nima', device=str(device))
    return wrapper.run(image_url)

@tool
def BRISQUE_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='brisque', device=str(device))
    return wrapper.run(image_url)

@tool
def NIQE_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='niqe', device=str(device))
    return wrapper.run(image_url)

@tool
def MANIQA_tool(image_url: str):
    """
    This is No-reference IQA model. This tool has no known strengths for any specific distortion
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='maniqa', device=str(device))
    return wrapper.run(image_url)

@tool
def LIQE_mix_tool(image_url: str):
    """
    This is No-reference IQA model. 
    Best at evaluating: 
    - Color distortion (color diffusion)
    """
    # Device follows the module-level CPU/GPU choice instead of hard-coded 'cuda'.
    wrapper = ImageIQAWrapper(model_name='liqe_mix', device=str(device))
    return wrapper.run(image_url)
